/*
 * object.c
 *
 * Copyright (C) 2016 Aleksandar Andrejevic <theflash@sdf.lonestar.org>
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

#include <object.h>
#include <process.h>
#include <thread.h>
#include <heap.h>
#include <pipe.h>

/* Per-type hook run once an object's last reference is dropped, just before its memory is freed. */
typedef void (*object_cleanup_proc_t)(object_t*);
/* Per-type hook run before a wait: fills in the wait_condition_t for the object.
   The void* is the caller-supplied per-object wait parameter. Any return value
   other than ERR_SUCCESS aborts the wait with that error. */
typedef dword_t (*object_pre_wait_proc_t)(object_t*, void*, wait_condition_t*);
/* Optional per-type hook run after the wait completes, with the same parameter and the wait result. */
typedef void (*object_post_wait_proc_t)(object_t*, void*, wait_result_t);

/* Type-specific hooks defined elsewhere in the kernel and referenced from type_info below. */
extern void file_cleanup(object_t*);
extern void file_instance_cleanup(object_t*);
extern dword_t semaphore_pre_wait(object_t *obj, void *parameter, wait_condition_t *condition);
extern void semaphore_post_wait(object_t *obj, void *parameter, wait_result_t result);

/* Guards every object's ref_count/open_count and the global object lists below. */
static DECLARE_LOCK(obj_lock);
/* Objects created with a NULL name. */
static DECLARE_LIST(anonymous_objects);
/* Named objects, bucketed by get_name_hash() (which returns a byte_t, hence 256 buckets). */
static list_entry_t named_objects[256];
/* All live objects, one list per object type (indexed by object_t::type). */
static list_entry_t objects_by_type[OBJECT_TYPE_MAX];

/*
 * Per-type hook table, indexed by object_t::type. Rows are positional, so their
 * order must match the object type enumeration declared elsewhere (presumably
 * object.h). The row labels below are inferred from the hook names; confirm
 * them against that enum.
 *
 * A NULL cleanup means nothing beyond free() is needed. A NULL pre_wait means
 * objects of that type cannot be waited on (the wait syscalls return
 * ERR_INVALID). A NULL post_wait means no post-wait processing.
 */
static struct
{
    object_cleanup_proc_t cleanup;     /* run when the last reference is dropped */
    object_pre_wait_proc_t pre_wait;   /* builds the wait condition; required to wait on the type */
    object_post_wait_proc_t post_wait; /* optional, run after the wait completes */
} type_info[OBJECT_TYPE_MAX] =
{
    { .cleanup = file_cleanup,          .pre_wait = NULL,               .post_wait = NULL                }, /* file (inferred) */
    { .cleanup = file_instance_cleanup, .pre_wait = NULL,               .post_wait = NULL                }, /* file instance (inferred) */
    { .cleanup = NULL,                  .pre_wait = NULL,               .post_wait = NULL                }, /* no hooks */
    { .cleanup = pipe_cleanup,          .pre_wait = pipe_pre_wait,      .post_wait = NULL                }, /* pipe (inferred) */
    { .cleanup = process_cleanup,       .pre_wait = process_pre_wait,   .post_wait = NULL                }, /* process (inferred) */
    { .cleanup = thread_cleanup,        .pre_wait = thread_pre_wait,    .post_wait = NULL                }, /* thread (inferred) */
    { .cleanup = memory_cleanup,        .pre_wait = NULL,               .post_wait = NULL                }, /* memory (inferred) */
    { .cleanup = NULL,                  .pre_wait = semaphore_pre_wait, .post_wait = semaphore_post_wait }, /* semaphore (inferred) */
    { .cleanup = NULL,                  .pre_wait = NULL,               .post_wait = NULL                }, /* no hooks */
};

/*
 * Bucket index for a named object: the sum of the characters of the
 * NUL-terminated string `name`, accumulated in (and truncated to) a byte_t.
 * The result indexes the 256-entry named_objects table.
 */
static inline byte_t get_name_hash(const char *name)
{
    byte_t hash;
    const char *p;

    for (hash = 0, p = name; *p != '\0'; p++) hash += *p;
    return hash;
}

/*
 * Stores obj in a free slot of the calling process's handle table (the current
 * process for user-mode callers, kernel_process otherwise). Returns the slot
 * index as the new handle.
 *
 * When every slot is in use, the table is doubled in size. Returns
 * INVALID_HANDLE if growing fails; the existing table is left intact in that
 * case.
 *
 * Does not touch obj's reference or open counts; callers account for those.
 */
static handle_t insert_object(object_t *obj, access_flags_t access_flags)
{
    handle_t i;
    process_t *proc = get_previous_mode() == USER_MODE ? get_current_process() : kernel_process;

    lock_acquire(&proc->handle_table_lock);

    for (i = 0; i < proc->handle_table_size; i++)
    {
        if (proc->handle_table[i].obj == NULL) goto store;
    }

    /* handle_table_size counts entries (it bounds indexing into handle_table),
       so the allocation size must be scaled by the entry size. */
    handle_info_t *expanded_table = heap_realloc(&evictable_heap,
                                                 proc->handle_table,
                                                 proc->handle_table_size * 2 * sizeof(handle_info_t));
    if (expanded_table == NULL)
    {
        i = INVALID_HANDLE;
        goto cleanup;
    }

    /* The first new slot receives obj below; mark the remaining new slots free. */
    for (i = proc->handle_table_size + 1; i < proc->handle_table_size * 2; i++) expanded_table[i].obj = NULL;

    proc->handle_table = expanded_table;
    i = proc->handle_table_size;
    proc->handle_table_size *= 2;

store:
    /* Shared by the free-slot and grown-table paths so both record the
       access flags and keep handle_count accurate. */
    proc->handle_table[i].obj = obj;
    proc->handle_table[i].access_flags = access_flags;
    proc->handle_count++;

cleanup:
    lock_release(&proc->handle_table_lock);
    return i;
}

/*
 * Decides whether the current user may access obj with every right in `access`.
 *
 * Callers holding PRIVILEGE_ACCESS_ALL always pass. Otherwise a single ACE must
 * grant all requested bits, and that ACE must name either the current uid or
 * ALL_USERS. Rights are not combined across several ACEs.
 */
static bool_t access_check(object_t *obj, access_flags_t access)
{
    dword_t uid;
    list_entry_t *entry;

    if (check_privileges(PRIVILEGE_ACCESS_ALL)) return TRUE;

    uid = get_current_uid();

    for (entry = obj->acl.next; entry != &obj->acl; entry = entry->next)
    {
        access_control_entry_t *ace = CONTAINER_OF(entry, access_control_entry_t, link);
        bool_t applies = (ace->uid == ALL_USERS) || (ace->uid == uid);
        bool_t covers = (ace->access_mask & access) == access;

        if (applies && covers) return TRUE;
    }

    return FALSE;
}

/*
 * Takes one more reference to object. The count is updated under the global
 * obj_lock, the same lock dereference() and the lookup helpers use for it.
 */
void reference(object_t *object)
{
    lock_acquire(&obj_lock);
    object->ref_count += 1;
    lock_release(&obj_lock);
}

/*
 * Drops one reference to object.
 *
 * When the last reference goes away, the object is first unlinked from the
 * name and type lists under obj_lock, so lookups can no longer find it. Then,
 * outside the lock, its type-specific cleanup hook (if any) runs and its
 * memory is freed.
 */
void dereference(object_t *object)
{
    dword_t remaining;

    lock_acquire(&obj_lock);
    remaining = --object->ref_count;
    if (remaining == 0)
    {
        list_remove(&object->by_name_list);
        list_remove(&object->by_type_list);
    }
    lock_release(&obj_lock);

    if (remaining != 0) return;

    object_cleanup_proc_t cleanup = type_info[object->type].cleanup;
    if (cleanup != NULL) cleanup(object);
    free(object);
}

/*
 * Looks up `handle` in the calling process's handle table (the current process
 * for user-mode callers, kernel_process otherwise). Takes a reference to the
 * object if the slot is in use and the object is of type `type` (any type
 * matches OBJECT_ANY_TYPE).
 *
 * Returns TRUE and stores the referenced object in *object on success; release
 * it with dereference(). Returns FALSE and leaves *object untouched when the
 * handle is out of range, unused, or names an object of another type. The
 * unreferenced object is never handed back in that case.
 *
 * Lock order: obj_lock, then the handle table lock (shared). This matches
 * open_object(), which holds obj_lock while insert_object() takes the table lock.
 */
bool_t reference_by_handle(handle_t handle, object_type_t type, object_t **object)
{
    bool_t result = FALSE;
    process_t *proc = get_previous_mode() == USER_MODE ? get_current_process() : kernel_process;

    lock_acquire(&obj_lock);
    lock_acquire_shared(&proc->handle_table_lock);

    /* Bounds-check under the table lock: insert_object() may grow and
       reallocate the table concurrently. */
    if (handle < proc->handle_table_size)
    {
        object_t *obj = proc->handle_table[handle].obj;

        if (obj != NULL && (type == OBJECT_ANY_TYPE || obj->type == type))
        {
            obj->ref_count++;
            *object = obj;
            result = TRUE;
        }
    }

    lock_release(&proc->handle_table_lock);
    lock_release(&obj_lock);
    return result;
}

/*
 * Finds a live named object whose name equals `name` and whose type is `type`,
 * and takes a reference to it.
 *
 * Returns TRUE and stores the object in *object on success; release it with
 * dereference(). Returns FALSE and leaves *object untouched when no such
 * object is registered.
 */
bool_t reference_by_name(const char *name, object_type_t type, object_t **object)
{
    list_entry_t *bucket = &named_objects[get_name_hash(name)];
    list_entry_t *entry;
    object_t *found = NULL;

    lock_acquire(&obj_lock);

    for (entry = bucket->next; entry != bucket && found == NULL; entry = entry->next)
    {
        object_t *candidate = CONTAINER_OF(entry, object_t, by_name_list);

        if (candidate->type != type || candidate->name == NULL) continue;
        if (strcmp(candidate->name, name) == 0) found = candidate;
    }

    if (found != NULL)
    {
        found->ref_count++;
        *object = found;
    }

    lock_release(&obj_lock);
    return (found != NULL) ? TRUE : FALSE;
}

/*
 * Grants the rights in `access` on obj to user `uid`.
 *
 * The caller must itself hold every right it is granting (checked with
 * access_check()); otherwise ERR_FORBIDDEN is returned. If uid already has an
 * ACE, the new rights are merged into it. Otherwise a new ACE is appended.
 *
 * Returns ERR_SUCCESS, ERR_FORBIDDEN, or ERR_NOMEMORY.
 */
dword_t grant_access(object_t *obj, dword_t uid, access_flags_t access)
{
    dword_t ret = ERR_SUCCESS;
    list_entry_t *ptr;
    access_control_entry_t *ace = NULL;

    reference(obj);
    lock_acquire(&obj->acl_lock);

    if (!access_check(obj, access))
    {
        ret = ERR_FORBIDDEN;
        goto cleanup;
    }

    /* Match on uid alone so the rights are merged into the user's existing
       ACE. Requiring the ACE to already contain `access` would make the merge
       a no-op and append a duplicate entry instead. */
    for (ptr = obj->acl.next; ptr != &obj->acl; ptr = ptr->next)
    {
        access_control_entry_t *entry = CONTAINER_OF(ptr, access_control_entry_t, link);

        if (entry->uid == uid)
        {
            ace = entry;
            break;
        }
    }

    if (ace != NULL)
    {
        ace->access_mask |= access;
        goto cleanup;
    }

    ace = malloc(sizeof(*ace));
    if (ace == NULL)
    {
        ret = ERR_NOMEMORY;
        goto cleanup;
    }

    ace->uid = uid;
    ace->access_mask = access;
    list_append(&obj->acl, &ace->link);

cleanup:
    lock_release(&obj->acl_lock);
    dereference(obj);
    return ret;
}

/*
 * Removes the rights in `access` from every ACE that belongs to user `uid` on obj.
 *
 * ACEs left with an empty mask are unlinked and freed. Only the object's owner
 * may revoke access; any other caller gets ERR_FORBIDDEN. Revoking rights the
 * user does not hold is not an error.
 *
 * Returns ERR_SUCCESS or ERR_FORBIDDEN.
 */
dword_t revoke_access(object_t *obj, dword_t uid, access_flags_t access)
{
    dword_t ret = ERR_SUCCESS;
    list_entry_t *ptr;
    list_entry_t *next;

    reference(obj);
    lock_acquire(&obj->acl_lock);

    if (get_current_uid() != obj->owner)
    {
        ret = ERR_FORBIDDEN;
        goto cleanup;
    }

    for (ptr = obj->acl.next; ptr != &obj->acl; ptr = next)
    {
        access_control_entry_t *ace = CONTAINER_OF(ptr, access_control_entry_t, link);

        /* Fetch the successor first: this entry may be freed below. */
        next = ptr->next;
        if (ace->uid != uid) continue;

        /* Clear only the requested bits; other rights on the ACE survive. */
        ace->access_mask &= ~access;

        if (ace->access_mask == 0)
        {
            list_remove(&ace->link);
            free(ace);
        }
    }

cleanup:
    lock_release(&obj->acl_lock);
    dereference(obj);
    return ret;
}

/*
 * Initializes a new kernel object and publishes it in the global object lists.
 *
 * obj->name (NULL for an anonymous object) and obj->type are read and must be
 * set by the caller. On success:
 *   - the object holds one reference, owned by the caller;
 *   - it has no open handles;
 *   - it is owned by the current user;
 *   - its ACL grants that user FULL_ACCESS.
 *
 * Returns ERR_SUCCESS, or one of these failures (on failure the object is not
 * published):
 *   - ERR_NOMEMORY if the initial ACE cannot be allocated;
 *   - ERR_EXISTS if a named object with the same name and type is already
 *     registered.
 */
dword_t create_object(object_t *obj)
{
    access_control_entry_t *ace = malloc(sizeof(*ace));
    if (ace == NULL) return ERR_NOMEMORY;

    obj->ref_count = 1;
    obj->open_count = 0;
    obj->owner = get_current_uid();
    list_init(&obj->acl);

    ace->uid = obj->owner;
    ace->access_mask = FULL_ACCESS;
    list_append(&obj->acl, &ace->link);

    lock_acquire(&obj_lock);

    if (obj->name != NULL)
    {
        list_entry_t *bucket = &named_objects[get_name_hash(obj->name)];
        list_entry_t *ptr;

        /* The duplicate check and the insertion happen within one obj_lock
           hold, so two concurrent creators cannot both register the same name. */
        for (ptr = bucket->next; ptr != bucket; ptr = ptr->next)
        {
            object_t *other = CONTAINER_OF(ptr, object_t, by_name_list);

            if (other->name != NULL && other->type == obj->type && strcmp(other->name, obj->name) == 0)
            {
                lock_release(&obj_lock);
                list_init(&obj->acl);
                free(ace);
                return ERR_EXISTS;
            }
        }

        list_append(bucket, &obj->by_name_list);
    }
    else
    {
        list_append(&anonymous_objects, &obj->by_name_list);
    }

    list_append(&objects_by_type[obj->type], &obj->by_type_list);
    lock_release(&obj_lock);

    return ERR_SUCCESS;
}

/*
 * Opens a new handle to obj in the calling process's handle table with the
 * rights in access_flags.
 *
 * Returns ERR_FORBIDDEN if the current user may not access obj that way, and
 * ERR_NOMEMORY if no handle slot could be obtained. On success *handle
 * receives the new handle, and obj gains one reference and one open count for it.
 *
 * Holds obj_lock across insert_object(), i.e. obj_lock is taken before the
 * handle table lock.
 * NOTE(review): the ACL is read here under obj_lock only, while
 * grant_access()/revoke_access() modify it under obj->acl_lock; confirm this
 * is safe.
 */
dword_t open_object(object_t *obj, access_flags_t access_flags, handle_t *handle)
{
    dword_t status;
    handle_t slot;

    lock_acquire(&obj_lock);

    if (!access_check(obj, access_flags))
    {
        status = ERR_FORBIDDEN;
    }
    else if ((slot = insert_object(obj, access_flags)) == INVALID_HANDLE)
    {
        status = ERR_NOMEMORY;
    }
    else
    {
        *handle = slot;
        obj->ref_count++;
        obj->open_count++;
        status = ERR_SUCCESS;
    }

    lock_release(&obj_lock);
    return status;
}

/*
 * Looks up the named object of type `type` and opens a handle to it in the
 * calling process with the rights in access_flags.
 *
 * Returns ERR_NOTFOUND if no such object exists, ERR_FORBIDDEN if access is
 * denied, and ERR_NOMEMORY if no handle slot could be obtained. On success
 * *handle receives the new handle.
 */
dword_t open_object_by_name(const char *name, object_type_t type, access_flags_t access_flags, handle_t *handle)
{
    object_t *obj;
    handle_t slot;
    dword_t status = ERR_SUCCESS;

    if (!reference_by_name(name, type, &obj)) return ERR_NOTFOUND;

    lock_acquire(&obj_lock);

    if (!access_check(obj, access_flags)) status = ERR_FORBIDDEN;
    else if ((slot = insert_object(obj, access_flags)) == INVALID_HANDLE) status = ERR_NOMEMORY;
    else
    {
        *handle = slot;
        obj->open_count++;
    }

    lock_release(&obj_lock);

    /* On success the reference taken by reference_by_name() now belongs to
       the new handle; on failure it has to be given back. */
    if (status != ERR_SUCCESS) dereference(obj);
    return status;
}

/*
 * Releases the open count and the reference held by one handle to obj.
 *
 * If that was the last reference, the object is unlinked from the global
 * lists under obj_lock. Then, outside the lock, its type-specific cleanup
 * hook (if any) runs and its memory is freed.
 */
void close_object_internal(object_t *obj)
{
    qword_t remaining;

    lock_acquire(&obj_lock);

    obj->open_count--;
    remaining = --obj->ref_count;

    if (remaining == 0)
    {
        /* An object being destroyed must not have handles left open. */
        ASSERT(obj->open_count == 0);
        list_remove(&obj->by_name_list);
        list_remove(&obj->by_type_list);
    }

    lock_release(&obj_lock);

    if (remaining != 0) return;

    if (type_info[obj->type].cleanup != NULL) type_info[obj->type].cleanup(obj);
    free(obj);
}

/*
 * System call: closes `handle` in the calling process's handle table (the
 * current process for user-mode callers, kernel_process otherwise).
 *
 * Returns ERR_SUCCESS, or ERR_NOTFOUND if the handle is out of range or unused.
 *
 * The slot is cleared while holding the handle table lock, but the object is
 * released only after that lock is dropped. Releasing takes obj_lock and may
 * run the object's cleanup hook, whereas open_object() and
 * reference_by_handle() acquire obj_lock *before* the handle table lock.
 * Nesting them the other way round here could deadlock.
 */
sysret_t syscall_close_object(handle_t handle)
{
    dword_t ret = ERR_NOTFOUND;
    object_t *obj = NULL;
    process_t *proc = get_previous_mode() == USER_MODE ? get_current_process() : kernel_process;

    reference(&proc->header);
    lock_acquire(&proc->handle_table_lock);

    if (handle < proc->handle_table_size && proc->handle_table[handle].obj != NULL)
    {
        obj = proc->handle_table[handle].obj;
        proc->handle_table[handle].obj = NULL;
        proc->handle_count--;
        ret = ERR_SUCCESS;
    }

    lock_release(&proc->handle_table_lock);

    /* The handle's reference is still ours, so obj stays valid until this call. */
    if (obj != NULL) close_object_internal(obj);

    dereference(&proc->header);
    return ret;
}

/*
 * System call: reports information about `handle` in the calling process.
 *
 * `type` selects what is written to `buffer` (of `size` bytes):
 *   HANDLE_INFO_NAME - the object's name, or "<anonymous>" for unnamed
 *                      objects. If it does not fit including the NUL
 *                      terminator, the copy is truncated and ERR_SMALLBUF
 *                      is returned.
 *   HANDLE_INFO_TYPE - the object's object_type_t. If size is too small,
 *                      nothing is written and ERR_SMALLBUF is returned.
 * Any other `type` yields ERR_INVALID. An out-of-range or unused handle
 * yields ERR_NOTFOUND.
 *
 * For user-mode callers, `buffer` is validated with check_usermode() and
 * filled through a zeroed kernel bounce buffer. A fault during the copy-out
 * yields ERR_BADPTR.
 */
sysret_t syscall_query_handle(handle_t handle, handle_info_type_t type, void *buffer, size_t size)
{
    dword_t ret = ERR_SUCCESS;
    process_t *proc;
    void *safe_buffer = NULL;
    processor_mode_t previous_mode = get_previous_mode();

    if (previous_mode == USER_MODE)
    {
        proc = get_current_process();
        if (!check_usermode(buffer, size)) return ERR_BADPTR;
        safe_buffer = malloc(size);
        if (safe_buffer == NULL) return ERR_NOMEMORY;
        memset(safe_buffer, 0, size);
    }
    else
    {
        proc = kernel_process;
        safe_buffer = buffer;
    }

    reference(&proc->header);
    lock_acquire(&proc->handle_table_lock);

    /* Validate the handle before indexing the table. */
    if (handle >= proc->handle_table_size || proc->handle_table[handle].obj == NULL)
    {
        ret = ERR_NOTFOUND;
    }
    else
    {
        object_t *obj = proc->handle_table[handle].obj;
        const char *name = obj->name ? obj->name : "<anonymous>";

        switch (type)
        {
        case HANDLE_INFO_NAME:
            strncpy(safe_buffer, name, size);
            if (size < strlen(name) + 1) ret = ERR_SMALLBUF;
            break;

        case HANDLE_INFO_TYPE:
            /* Check first: writing into a smaller buffer would overflow it. */
            if (size < sizeof(object_type_t)) ret = ERR_SMALLBUF;
            else *((object_type_t*)safe_buffer) = obj->type;
            break;

        default:
            ret = ERR_INVALID;
        }
    }

    lock_release(&proc->handle_table_lock);
    dereference(&proc->header);

    if (previous_mode == USER_MODE)
    {
        /* Copy out after dropping the table lock so a fault on the user
           buffer is not taken while the lock is held. */
        if (ret != ERR_NOTFOUND)
        {
            EH_TRY memcpy(buffer, safe_buffer, size);
            EH_CATCH ret = ERR_BADPTR;
            EH_DONE;
        }

        free(safe_buffer);
    }

    return ret;
}

/*
 * System call: duplicates `handle` from the handle table of `source_process`
 * into the table of `dest_process`, keeping the original handle's access flags.
 * Passing INVALID_HANDLE for either process selects the calling process (the
 * current process for user-mode callers, kernel_process otherwise).
 *
 * On success the new handle is stored in *duplicate.
 *
 * Returns:
 *   - ERR_BADPTR if `duplicate` is not a valid user pointer or faults when written;
 *   - ERR_INVALID if a process handle or `handle` is invalid;
 *   - otherwise the result of open_object() (e.g. ERR_FORBIDDEN, ERR_NOMEMORY).
 */
sysret_t syscall_duplicate_handle(handle_t source_process, handle_t handle, handle_t dest_process, handle_t *duplicate)
{
    process_t *proc;
    object_t *obj = NULL;
    access_flags_t access_flags;
    handle_t safe_handle = INVALID_HANDLE;
    dword_t ret;

    if (get_previous_mode() == USER_MODE && !check_usermode(duplicate, sizeof(handle_t)))
    {
        return ERR_BADPTR;
    }

    if (source_process != INVALID_HANDLE)
    {
        if (!reference_by_handle(source_process, OBJECT_PROCESS, (object_t**)&proc)) return ERR_INVALID;
    }
    else
    {
        proc = get_previous_mode() == USER_MODE ? get_current_process() : kernel_process;
        reference(&proc->header);
    }

    /* Same lock order as reference_by_handle(): obj_lock, then the table lock.
       Calling reference() while already holding the table lock would take
       obj_lock second and could deadlock against open_object(). */
    lock_acquire(&obj_lock);
    lock_acquire_shared(&proc->handle_table_lock);

    if (handle < proc->handle_table_size && proc->handle_table[handle].obj != NULL)
    {
        obj = proc->handle_table[handle].obj;
        access_flags = proc->handle_table[handle].access_flags;
        obj->ref_count++;
    }

    lock_release(&proc->handle_table_lock);
    lock_release(&obj_lock);
    dereference(&proc->header);

    if (obj == NULL) return ERR_INVALID;

    if (dest_process != INVALID_HANDLE)
    {
        if (!reference_by_handle(dest_process, OBJECT_PROCESS, (object_t**)&proc))
        {
            dereference(obj);
            return ERR_INVALID;
        }
    }
    else
    {
        proc = get_previous_mode() == USER_MODE ? get_current_process() : kernel_process;
        reference(&proc->header);
    }

    process_t *old_process = switch_process(proc);
    ret = open_object(obj, access_flags, &safe_handle);
    switch_process(old_process);

    dereference(&proc->header);
    dereference(obj);

    /* Only report a handle that was actually created. */
    if (ret != ERR_SUCCESS) return ret;

    EH_TRY
    {
        *duplicate = safe_handle;
    }
    EH_CATCH
    {
        ret = ERR_BADPTR;
    }
    EH_DONE;

    return ret;
}

/*
 * Iterates over all live objects of type `type`.
 *
 * Pass *object == NULL to get the first object. Pass the object returned by
 * the previous call to get the next one. Each returned object carries a
 * reference, which the following call releases. A caller that stops
 * iterating early must dereference() the last object it received.
 *
 * Returns ERR_SUCCESS with *object set, or ERR_NOMORE with *object = NULL
 * once the list is exhausted.
 */
dword_t enum_objects_by_type(object_type_t type, object_t **object)
{
    dword_t ret;
    list_entry_t *ptr;
    object_t *previous = *object;

    lock_acquire(&obj_lock);

    /* `previous` is still linked: the reference handed out by the previous
       call keeps it from being removed (assuming the caller passes that object). */
    ptr = previous ? previous->by_type_list.next : objects_by_type[type].next;

    if (ptr != &objects_by_type[type])
    {
        *object = CONTAINER_OF(ptr, object_t, by_type_list);

        /* Reference it while still holding obj_lock. After the lock is dropped,
           a concurrent dereference() could free it before we get a chance. */
        (*object)->ref_count++;
        ret = ERR_SUCCESS;
    }
    else
    {
        *object = NULL;
        ret = ERR_NOMORE;
    }

    lock_release(&obj_lock);

    if (previous) dereference(previous);
    return ret;
}

/*
 * Common implementation of syscall_wait_for_any() and syscall_wait_for_all().
 *
 * handles:        array of `count` handles in the calling process.
 * parameters:     optional array of `count` per-object wait parameters, passed
 *                 to each type's pre_wait/post_wait hooks. May be NULL, in
 *                 which case every hook receives NULL.
 * timeout:        passed through to scheduler_wait().
 * condition_type: how the per-object conditions are combined (the callers
 *                 pass WAIT_GROUP_ANY or WAIT_GROUP_ALL).
 *
 * For user-mode callers both arrays are validated and copied into kernel memory
 * first. Every object is referenced for the duration of the wait.
 *
 * Returns:
 *   - ERR_INVALID for count == 0, an invalid handle, or a type without a
 *     pre_wait hook;
 *   - ERR_BADPTR or ERR_NOMEMORY on copy or allocation failure;
 *   - any error returned by a pre_wait hook;
 *   - otherwise ERR_CANCELED, ERR_TIMEOUT or ERR_SUCCESS according to the
 *     scheduler_wait() result.
 */
static sysret_t wait_for_objects(const handle_t *handles, void * const *parameters, size_t count, timeout_t timeout, wait_condition_type_t condition_type)
{
    dword_t ret = ERR_SUCCESS;
    object_t **objects = NULL;
    wait_condition_t *condition = NULL;
    const handle_t *safe_handles;
    void * const *safe_parameters = NULL;
    processor_mode_t previous_mode = get_previous_mode();
    if (count == 0) return ERR_INVALID;

    if (previous_mode == USER_MODE)
    {
        if (!check_usermode(handles, count * sizeof(handle_t))) return ERR_BADPTR;
        if (parameters && !check_usermode(parameters, count * sizeof(void*))) return ERR_BADPTR;

        safe_handles = calloc(count, sizeof(handle_t));
        if (safe_handles == NULL) return ERR_NOMEMORY;

        if (parameters)
        {
            safe_parameters = calloc(count, sizeof(void*));
            if (safe_parameters == NULL)
            {
                free((void*)safe_handles);
                return ERR_NOMEMORY;
            }
        }

        EH_TRY
        {
            memcpy((handle_t*)safe_handles, handles, count * sizeof(handle_t));
            if (safe_parameters) memcpy((void*)safe_parameters, parameters, count * sizeof(void*));
        }
        EH_CATCH
        {
            free((void*)safe_handles);
            if (safe_parameters) free((void*)safe_parameters);
            EH_ESCAPE(return ERR_BADPTR);
        }
        EH_DONE;
    }
    else
    {
        safe_handles = handles;
        safe_parameters = parameters;
    }

    /* Zeroed so the cleanup path can tell which entries were referenced. */
    if (!(objects = calloc(count, sizeof(object_t*))))
    {
        ret = ERR_NOMEMORY;
        goto cleanup;
    }

    size_t i;
    for (i = 0; i < count; i++)
    {
        if (!reference_by_handle(safe_handles[i], OBJECT_ANY_TYPE, &objects[i]))
        {
            ret = ERR_INVALID;
            goto cleanup;
        }
    }

    /* Group condition followed by one sub-condition pointer per object. */
    if (!(condition = malloc(sizeof(wait_condition_t) + count * sizeof(wait_condition_t*))))
    {
        ret = ERR_NOMEMORY;
        goto cleanup;
    }

    memset(condition, 0, sizeof(wait_condition_t) + count * sizeof(wait_condition_t*));
    condition->type = condition_type;

    for (i = 0; i < count; i++)
    {
        /* `parameters` is optional; hooks receive NULL when it was omitted. */
        void *param = safe_parameters ? safe_parameters[i] : NULL;

        if (!type_info[objects[i]->type].pre_wait)
        {
            ret = ERR_INVALID;
            goto cleanup;
        }

        wait_condition_t *cond = malloc(sizeof(wait_condition_t));
        if (!cond)
        {
            ret = ERR_NOMEMORY;
            goto cleanup;
        }

        ret = type_info[objects[i]->type].pre_wait(objects[i], param, cond);
        if (ret != ERR_SUCCESS)
        {
            free(cond);
            goto cleanup;
        }

        condition->conditions[i] = cond;
    }

    wait_result_t result = scheduler_wait(condition, timeout);
    if (result == WAIT_CANCELED) ret = ERR_CANCELED;
    else if (result == WAIT_TIMED_OUT) ret = ERR_TIMEOUT;

    for (i = 0; i < count; i++)
    {
        if (type_info[objects[i]->type].post_wait)
        {
            type_info[objects[i]->type].post_wait(objects[i], safe_parameters ? safe_parameters[i] : NULL, result);
        }
    }

cleanup:
    if (condition)
    {
        for (i = 0; i < count; i++) if (condition->conditions[i]) free(condition->conditions[i]);
        free(condition);
    }

    if (objects)
    {
        for (i = 0; i < count; i++) if (objects[i]) dereference(objects[i]);
        free(objects);
    }

    if (previous_mode == USER_MODE)
    {
        free((void*)safe_handles);
        if (safe_parameters) free((void*)safe_parameters);
    }

    return ret;
}

/*
 * System call: waits on a single object.
 *
 * The object's pre_wait hook builds the wait condition from `parameter`.
 * After scheduler_wait() returns, the post_wait hook (if any) runs, matching
 * wait_for_objects().
 *
 * Returns:
 *   - ERR_INVALID for a bad handle or a type that cannot be waited on;
 *   - any error from pre_wait;
 *   - otherwise ERR_SUCCESS, ERR_CANCELED or ERR_TIMEOUT.
 *
 * The object reference is released on every path.
 */
sysret_t syscall_wait_for_one(handle_t handle, void *parameter, timeout_t timeout)
{
    dword_t ret;
    object_t *object;
    wait_condition_t condition;
    wait_result_t result;

    if (!reference_by_handle(handle, OBJECT_ANY_TYPE, &object)) return ERR_INVALID;

    if (!type_info[object->type].pre_wait)
    {
        dereference(object);
        return ERR_INVALID;
    }

    ret = type_info[object->type].pre_wait(object, parameter, &condition);
    if (ret != ERR_SUCCESS)
    {
        /* Do not leak the reference taken by reference_by_handle(). */
        dereference(object);
        return ret;
    }

    result = scheduler_wait(&condition, timeout);

    /* Same post-processing as a group wait, so single and group waits on one
       object behave alike (e.g. semaphores, which define post_wait). */
    if (type_info[object->type].post_wait)
    {
        type_info[object->type].post_wait(object, parameter, result);
    }

    dereference(object);

    switch (result)
    {
    case WAIT_CONDITION_HIT:
        return ERR_SUCCESS;
    case WAIT_CANCELED:
        return ERR_CANCELED;
    case WAIT_TIMED_OUT:
        return ERR_TIMEOUT;

    default:
        KERNEL_CRASH("Unexpected scheduler wait result");
        return ERR_INVALID;
    }
}

/*
 * System call: group wait on `count` handles with WAIT_GROUP_ANY semantics
 * (presumably: completes once any one object's condition is met). Thin
 * wrapper around wait_for_objects().
 */
sysret_t syscall_wait_for_any(const handle_t *handles, void * const *parameters, size_t count, timeout_t timeout)
{
    const wait_condition_type_t group_mode = WAIT_GROUP_ANY;

    return wait_for_objects(handles, parameters, count, timeout, group_mode);
}

/*
 * System call: group wait on `count` handles with WAIT_GROUP_ALL semantics
 * (presumably: completes once every object's condition is met). Thin
 * wrapper around wait_for_objects().
 */
sysret_t syscall_wait_for_all(const handle_t *handles, void * const *parameters, size_t count, timeout_t timeout)
{
    const wait_condition_type_t group_mode = WAIT_GROUP_ALL;

    return wait_for_objects(handles, parameters, count, timeout, group_mode);
}

/*
 * Initializes the global object lists: every bucket of the named-object table
 * and every per-type list. create_object() appends to these lists, so this
 * must run before any object is created.
 */
void object_init(void)
{
    const size_t bucket_count = sizeof(named_objects) / sizeof(named_objects[0]);
    const size_t type_count = sizeof(objects_by_type) / sizeof(objects_by_type[0]);

    list_init_array(named_objects, bucket_count);
    list_init_array(objects_by_type, type_count);
}
